import optuna
from torch.optim import Adam
import torch
from Dataset import load
import argparse
from torch.nn.functional import l1_loss, mse_loss
from reform.NeiborEmbDirSchNet import DirSchNet
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from reform import Utils
import numpy as np
import time
from torch.utils.data import TensorDataset, DataLoader, random_split

# Not referenced by any code visible in this file.
EPS = 1e-6
# Not referenced by any code visible in this file.
step_plot = -1

# Passed to Utils.train_grad together with ratio_dy on every training batch;
# presumably the weight of the y-term of the training loss (confirm in Utils).
# work() overwrites this global when it receives a "ratio_y" keyword argument.
ratio_y = 0.01
# Passed to Utils.train_grad right after ratio_y; presumably the weight of the
# dy-term of the training loss (confirm in Utils).
ratio_dy = 1


def buildModel(**kwargs):
    """Construct a DirSchNet from keyword arguments and report its size.

    When ``hid_dim_s`` is not supplied, both ``hid_dim_s`` and ``hid_dim_v``
    are set to ``kwargs["hid_dim"]`` (a missing ``hid_dim`` raises KeyError).
    The module-level ``y_mean``, ``y_std`` and ``global_y_mean`` are passed to
    the constructor alongside the keyword arguments.

    Returns:
        The newly created DirSchNet instance (not yet moved to any device).
    """
    if "hid_dim_s" not in kwargs:
        shared_width = kwargs["hid_dim"]
        kwargs.update(hid_dim_s=shared_width, hid_dim_v=shared_width)

    model = DirSchNet(y_mean=y_mean,
                      y_std=y_std,
                      global_y_mean=global_y_mean,
                      **kwargs)

    # Count only the parameters that will receive gradients.
    n_trainable = sum(
        weight.numel() for weight in model.parameters()
        if weight.requires_grad)
    print(f"numel {n_trainable}")
    return model


# ---- command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset', default="benzene", type=str)
parser.add_argument('--modname', default="0", type=str)
parser.add_argument('--test', action="store_true")
parser.add_argument('--gemnet_split', action="store_true")
args = parser.parse_args()

# Checkpoint path used by work() for saving and reloading model weights.
modfilename = f"save_mod/{args.dataset}.dirschnet.pt"

# ---- data -------------------------------------------------------------------
device = torch.device("cuda")
dataset = load(args.dataset)
# Only these dataset names are supported; anything else is rejected after
# the dataset has been loaded.
if args.dataset not in {
        'benzene', 'uracil', 'naphthalene', 'aspirin', 'salicylic_acid',
        'malonaldehyde', 'ethanol', 'toluene'
}:
    raise NotImplementedError
ratio = [950, 50]

# Length of the first sample's z tensor; used below as the per-sample size
# when reshaping the flat data tensors.
N = dataset[0].z.shape[0]

# Centre the y targets on their global mean and store them as float32.
global_y_mean = torch.mean(dataset.data.y)
dataset.data.y = (dataset.data.y - global_y_mean).to(torch.float32)

# One row per sample: (z, pos, y, dy).
ds = TensorDataset(
    dataset.data.z.reshape(-1, N),
    dataset.data.pos.reshape(-1, N, 3),
    dataset.data.y.reshape(-1, 1),
    dataset.data.dy.reshape(-1, N, 3),
)

# Filled in by work() from the training split before the model is built.
y_mean = None
y_std = None


def work(lr: float = 1e-3,
         initlr_ratio: float = 1e-1,
         minlr_ratio: float = 1e-3,
         total_step: int = 3000,
         batch_size: int = 32,
         save_model: bool = False,
         do_test: bool = False,
         jump_train: bool = False,
         search_hp: bool = False,
         max_early_stop: int = 500,
         patience: int = 90,
         warmup: int = 30,
         **kwargs):
    """Train a model on a fresh random split and return its validation score.

    The module-level dataset ``ds`` is randomly split into train / validation
    / test parts, the module-level ``y_mean`` / ``y_std`` are set from the
    training targets, a model is built with ``buildModel(**kwargs)`` and,
    unless ``jump_train`` is set, trained with Adam, a per-batch warm-up
    ``StepLR`` and a per-epoch ``ReduceLROnPlateau``.

    Args:
        lr: Target learning rate after warm-up (initial rate if no warm-up).
        initlr_ratio: With ``warmup > 0`` training starts at
            ``lr * initlr_ratio``.
        minlr_ratio: Lower bound of the plateau scheduler is
            ``lr * minlr_ratio``.
        total_step: Maximum number of epochs.
        batch_size: Batch size handed to ``Utils.tensorDataloader``.
        save_model: Save the state dict to ``modfilename`` whenever the
            validation loss improves.
        do_test: After training, reload ``modfilename`` and print the
            train / validation / test scores. The checkpoint must exist.
        jump_train: Skip the training loop entirely.
        search_hp: Use the 950 / 256 / rest split instead of the default.
        max_early_stop: Stop once more than this many consecutive epochs
            pass without a validation improvement.
        patience: ``patience`` of the plateau scheduler.
        warmup: Number of warm-up epochs; ``0`` disables warm-up.
        **kwargs: Forwarded to ``buildModel``. A ``"ratio_y"`` entry also
            overwrites the module-level ``ratio_y`` used by the training loss.

    Returns:
        ``min(best_val_loss, 10.0)`` where the validation loss is
        ``0.1 * val_loss_y + val_loss_dy``; ``10.0`` as soon as a NaN loss is
        seen (and therefore also when training is skipped).
    """
    global y_mean, y_std, ratio_y, ratio_dy
    if "ratio_y" in kwargs:
        ratio_y = kwargs["ratio_y"]

    # Score reported for diverged runs and upper cap on the returned value.
    nan_penalty = 1e1

    if search_hp:
        split_sizes = [950, 256, len(ds) - 950 - 256]
    elif args.gemnet_split:
        split_sizes = [1000, 1000, len(ds) - 2000]
    else:
        split_sizes = [950, 50, len(ds) - 1000]
    trn_ds, val_ds, tst_ds = random_split(ds, split_sizes)

    # Materialise the validation and training splits as single device-resident
    # batches (one DataLoader pass with batch_size == split size).
    val_d = next(
        iter(DataLoader(val_ds, batch_size=len(val_ds), shuffle=False)))
    val_d = [_.to(device) for _ in val_d]
    trn_d = next(
        iter(DataLoader(trn_ds, batch_size=len(trn_ds), shuffle=False)))
    trn_d = [_.to(device) for _ in trn_d]
    trn_dl = Utils.tensorDataloader(trn_d, batch_size, True, device)

    # Target statistics are read by buildModel through the module globals.
    y_mean = torch.mean(trn_d[2]).item()
    y_std = torch.std(trn_d[2]).item()
    mod = buildModel(**kwargs).to(device)
    best_val_loss = float("inf")

    if not jump_train:
        opt = Adam(mod.parameters(),
                   lr=lr * initlr_ratio if warmup > 0 else lr)
        # Warm-up multiplies the rate by a constant factor after every batch
        # so that it climbs from lr * initlr_ratio to lr over `warmup` epochs.
        # The number of batches per epoch is derived from the actual training
        # split (it used to be hard-coded to 950 samples, which overshoots lr
        # for the 1000-sample gemnet split) and is kept >= 1 so that a batch
        # size larger than the split cannot cause a division by zero.
        # NOTE(review): assumes Utils.tensorDataloader yields
        # len(trn_ds) // batch_size batches per epoch - confirm in Utils.
        steps_per_epoch = max(1, len(trn_ds) // batch_size)
        warmup_gamma = ((1 / initlr_ratio)**(1 / (warmup * steps_per_epoch))
                        if warmup > 0 else 1)
        scd1 = StepLR(opt, 1, gamma=warmup_gamma)
        scd = ReduceLROnPlateau(opt,
                                "min",
                                0.8,
                                patience=patience,
                                min_lr=lr * minlr_ratio,
                                threshold=0.0001)
        early_stop = 0
        for epoch in range(total_step):
            curlr = opt.param_groups[0]["lr"]
            epoch_losses_y = []
            epoch_losses_dy = []

            for batch in trn_dl:
                trn_loss_y, trn_loss_dy = Utils.train_grad(
                    batch, opt, mod, mse_loss, ratio_y, ratio_dy)
                if np.isnan(trn_loss_dy):
                    return nan_penalty
                epoch_losses_y.append(trn_loss_y)
                epoch_losses_dy.append(trn_loss_dy)
                if epoch < warmup:
                    scd1.step()
            trn_loss_y = np.average(epoch_losses_y)
            trn_loss_dy = np.average(epoch_losses_dy)

            val_loss_y, val_loss_dy = Utils.test_grad(val_d, mod, l1_loss)
            val_loss = 0.1 * val_loss_y + val_loss_dy
            # Bail out before feeding a NaN to the plateau scheduler.
            if np.isnan(val_loss):
                return nan_penalty
            early_stop += 1
            scd.step(val_loss)
            if val_loss < best_val_loss:
                early_stop = 0
                best_val_loss = val_loss
                if save_model:
                    torch.save(mod.state_dict(), modfilename)
            if early_stop > max_early_stop:
                break
            print(
                f"iter {epoch} lr {curlr:.4e} trn E {trn_loss_y:.4f} F {trn_loss_dy:.4f} val E {val_loss_y:.4f} F {val_loss_dy:.4f} "
            )
            if epoch % 10 == 0:
                # Empty print used only to flush buffered stdout periodically.
                print("", end="", flush=True)
            # Treat an exploding training dy-loss as a failed run.
            if trn_loss_dy > 1000:
                return min(best_val_loss, nan_penalty)

    if do_test:
        # Evaluate the checkpoint stored on disk, not the in-memory weights.
        mod.load_state_dict(torch.load(modfilename, map_location="cpu"))
        mod = mod.to(device)
        tst_dl = DataLoader(tst_ds, 1024)
        tst_score = []
        num_mol = []
        for batch in tst_dl:
            num_mol.append(batch[0].shape[0])
            batch = tuple(_.to(device) for _ in batch)
            tst_score.append(Utils.test_grad(batch, mod, l1_loss))
        # Average the per-batch scores weighted by batch size.
        num_mol = np.array(num_mol)
        tst_score = np.array(tst_score)
        tst_score = np.sum(tst_score *
                           (num_mol.reshape(-1, 1) / num_mol.sum()),
                           axis=0)
        trn_score = Utils.test_grad(trn_d, mod, l1_loss)
        val_score = Utils.test_grad(val_d, mod, l1_loss)
        print(trn_score, val_score, tst_score)
    return min(best_val_loss, nan_penalty)


# Keyword arguments held constant during the hyperparameter search; search()
# copies this mapping and adds the per-trial suggestions before calling work().
best_fixed_p = dict(
    max_z=20,
    wd=0,
    lin1_tailact=True,
    ln_lin1=True,
    ln_s2v=True,
    batch_size=16,
    lr=0.001,
    warmup=100,
    rbf='nexpnorm',
    patience=120,
    initlr_ratio=0.01,
    ef_dim=32,
    hid_dim=256,
    minlr_ratio=0.01,
    rbound_lower=0.0,
)


def search(trial: optuna.Trial):
    """Optuna objective: sample one configuration and score it with work().

    Starts from a copy of ``best_fixed_p``, adds the trial's suggestions for
    the searched hyperparameters and runs ``work`` in ``search_hp`` mode for
    at most 4000 epochs (early-stop limit 500).

    Returns:
        The value returned by ``work`` for the sampled configuration.
    """
    params = dict(best_fixed_p)

    # Numeric hyperparameters.
    params['ratio_y'] = trial.suggest_float("ratio_y", 1e-3, 3e-2, step=1e-3)
    params['cutoff'] = trial.suggest_float("cutoff", 4, 8, step=0.5)
    params['num_mplayer'] = trial.suggest_int("num_mplayer", 2, 6)

    # Boolean switches, suggested in a fixed order.
    for flag in ("ev_decay", "ef_decay", "use_dir2", "dir2mask_tailact",
                 "ef2mask_tailact", "ln_emb", "add_ef2dir"):
        params[flag] = trial.suggest_categorical(flag, [True, False])

    score = work(**params,
                 total_step=4000,
                 max_early_stop=500,
                 search_hp=True)
    print("", flush=True)
    return score


# Create (or resume) the per-dataset study stored in an SQLite file under
# Opt/, run 200 trials of search() and report the best result.
study = optuna.create_study(
    study_name=args.dataset,
    storage=f"sqlite:///Opt/{args.dataset}.db",
    direction="minimize",
    load_if_exists=True,
)
study.optimize(search, n_trials=200)
print("best params ", study.best_params)
print("best valf1 ", study.best_value)
